import os
import json
import torch
import h5py
import numpy as np
from scipy.stats import pearsonr
import webdataset as wds
from torch.utils.data import DataLoader
import random

# ---------- Paths ----------
# Subject number. It is spliced into paths as "subj0<subj>", so only single-digit
# values produce a two-digit tag.
subj = 1
data_path = "dataset"  # dataset root: holds the wds/ shards and the HDF5 files
output_dir = "./subj0{}_prior_vae_eye".format(subj)

# ---------- 2. Load the 1000 test images via the WebDataset test shard ----------
def my_split_by_node(urls):
    """Identity node splitter: every node keeps the full URL list unchanged."""
    return urls


test_url = f"{data_path}/wds/subj0{subj}/new_test/0.tar"

# Build the pipeline one stage at a time; each call returns the extended pipeline.
test_data = wds.WebDataset(test_url, resampled=False, nodesplitter=my_split_by_node)
test_data = test_data.shuffle(750, initial=1500, rng=random.Random(42))
test_data = test_data.decode("torch")
test_data = test_data.rename(
    behav="behav.npy",
    past_behav="past_behav.npy",
    future_behav="future_behav.npy",
    olds_behav="olds_behav.npy",
)
test_data = test_data.to_tuple("behav", "past_behav", "future_behav", "olds_behav")

# NOTE(review): the dataset has a single shard, so with num_workers=4 presumably only
# one worker receives data — confirm this is intended for the installed webdataset version.
test_dl = DataLoader(
    test_data,
    batch_size=3000,
    shuffle=False,
    drop_last=False,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
)

# Image store; the file name suggests 224px float16 COCO images (not checked here).
h5_file = h5py.File(os.path.join(data_path, "coco_images_224_float16.hdf5"), mode="r")
images = h5_file["images"]

# CLIP embeddings; presumably shaped (N, 257, 768) — TODO confirm against the file.
clip_emb_file = h5py.File(f"{data_path}/clip_embeddings.hdf5", mode="r")
clip_embeddings = clip_emb_file["embeddings"]

# Result placeholders: test_image / test_clip_emb are filled by the collection loop
# that follows; image_idx_1000 is not assigned anywhere in the visible code.
test_image = image_idx_1000 = test_clip_emb = None

# ---------- 3. Collect the unique test images and their CLIP embeddings ----------
image_chunks = []  # one (1, ...) tensor per unique image, concatenated once at the end
clip_chunks = []
for behav, _, _, _ in test_dl:
    # Element [:, 0, 0] of each trial's behav tensor is used as the row index into
    # both HDF5 datasets.
    image_idx = behav[:, 0, 0].cpu().long()
    # torch.unique sorts ascending, so the output order follows the image index and
    # does not depend on the upstream shuffle.
    unique_image, counts = torch.unique(image_idx, return_counts=True)
    # Sanity check kept from the original: an image may appear at most 3 times in
    # the batch (1 or 2 repeats were padded up to 3; more than 3 failed).
    assert bool((counts <= 3).all()), "an image occurs more than 3 times in the test batch"
    for im in unique_image.tolist():
        # Plain int index; [None] keeps a leading axis so the rows can be concatenated.
        # torch.Tensor(...) converts to the default tensor type, as before.
        image_chunks.append(torch.Tensor(images[im][None]))
        clip_chunks.append(torch.Tensor(clip_embeddings[im][None]))
    # Only the first batch is used; batch_size=3000 is assumed to cover the whole
    # test shard — TODO confirm.
    break

if not image_chunks:
    raise RuntimeError("test_dl produced no samples; no test images were collected")

# Concatenate once instead of vstack-ing inside the loop (which is O(n^2) copying).
test_image = torch.cat(image_chunks, dim=0)
test_clip_emb = torch.cat(clip_chunks, dim=0)

# All needed rows now live in torch tensors, so release the HDF5 handles.
h5_file.close()
clip_emb_file.close()

# Explicit check (unlike assert, not stripped under `python -O`).
if test_image.shape[0] != 1000:
    raise RuntimeError(f"expect 1000 images, got {test_image.shape[0]}")

# ---------- 4. Save ----------
# torch.save does not create missing directories, so ensure output_dir exists first.
os.makedirs(output_dir, exist_ok=True)
torch.save(test_image, os.path.join(output_dir, "test_images_1000.pt"))
torch.save(test_clip_emb, os.path.join(output_dir, "test_clip_1000.pt"))

print("✅ 完成：已保存 1000 张图")